
//::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::
/** @author  John Miller
 *  @version 1.1
 *  @date    Sun Sep 29 18:56:17 EDT 2013
 *  @see     LICENSE (MIT style license file).
 *  @see     http://www.cs.iastate.edu/~honavar/smo-svm.pdf
 */

package scalation.minima

import math.abs

import scalation.linalgebra.{MatrixD, VectorD}
import scalation.linalgebra_gen.VectorN
import scalation.linalgebra_gen.Vectors.VectorI
import scalation.random.Random
import scalation.util.Error

//::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::
/** The `SeqMinOpt` class is a translation of the C++ code found at the above URL
 *  into Scala and includes a few simplifications (e.g., it currently only works
 *  for linear kernels, dense data and binary classification).  The learned model
 *  is '(w dot z) - b', see `solve`.
 *  @param x        the training data points stored as rows in a matrix
 *  @param y        the classification of the data points stored in a vector
 *                  (one entry per row of 'x' is read; the sizes are not checked here)
 *  @param penalty  the crossing penalty 'C', i.e., the upper bound applied to each
 *                  Lagrange multiplier (defaults to 0.05, the value that used to be
 *                  hard-coded)
 */
class SeqMinOpt (x: MatrixD, y: VectorI, penalty: Double = 0.05)
      extends Error
//    extends Minimizer with Error
{
    type Pair = Tuple2 [Double, Double]

    private val DEBUG   = true               // debug flag (guards all tracing output)
    private val EPSILON = 1E-3               // a number close to zero
    private val TOL     = 1E-3               // tolerance level
    private val C       = penalty            // crossing penalty
    private val m       = x.dim1             // number of rows (data points)
    private val n       = x.dim2             // number of columns (variables)
    private val ESI     = m                  // support vectors in range 0..ESI
    private val alp     = new VectorD (m)    // alpha (Lagrange multipliers)
    private val eCache  = new VectorD (m)    // error cache
    private val w       = new VectorD (n)    // weights

    // Working state shared between 'takeStep', 'computeLH', 'update_threshold'
    // and 'update'; it is (re)loaded by 'takeStep' before those methods read it.
    private var al1     = 0.0                // old Lagrange multiplier 1
    private var a1      = 0.0                // new Lagrange multiplier 1
    private var al2     = 0.0                // old Lagrange multiplier 2
    private var a2      = 0.0                // new Lagrange multiplier 2
    private var b       = 0.0                // threshold
    private var delta_b = 0.0                // change in threshold

    private val rn = new Random ()           // random number generator

    //::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::
    /** Determine whether a multiplier value lies strictly between its bounds
     *  0 and 'C' (a "non-bound" multiplier).
     *  @param a  the value of a Lagrange multiplier
     */
    private def nonBound (a: Double): Boolean = a > 0.0 && a < C

    //::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::
    /** Return the error for data point 'i': the cached value when its multiplier
     *  is non-bound, otherwise the freshly computed 'func (i) - y(i)'.
     *  @param i  the row index into the data matrix
     *  @param a  the current multiplier value for data point 'i'
     */
    private def errorFor (i: Int, a: Double): Double =
    {
        if (nonBound (a)) eCache(i) else func (i) - y(i)
    } // errorFor

    //::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::
    /** Compute the value using the learned function (assumes linear, dense).
     *  @param k  the row index into the data matrix
     */
    def func (k: Int): Double = (w dot x(k)) - b

    //::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::
    /** Compute the value using the kernel function (assumes linear, dense).
     *  @param i1  the first row index into the data matrix
     *  @param i2  the second row index into the data matrix
     */
    def kernel_func (i1: Int, i2: Int): Double = x(i1) dot x(i2)

    //::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::
    /** Optimize by replacing old values of Lagrange multipliers al1, al2 with
     *  new values a1 and a2.  Returns true if the multipliers were changed and
     *  stored, false if the step was skipped (same index, empty feasible segment
     *  'l == h', or a negligible change in a2).
     *  @param i1  the index for the first Lagrange multiplier (alpha)
     *  @param i2  the index for the second Lagrange multiplier (alpha)
     */
    def takeStep (i1: Int, i2: Int): Boolean =
    {
        if (DEBUG) println ("takeStep (" + i1 + ", " + i2 + ")")
        if (i1 == i2) {
            if (DEBUG) println ("takeStep: skip if i1 == i2")
            false
        } else {
            // look up al1, y1, e1, al2, y2, e2
            al1 = alp(i1)                       // old multiplier 1
            val y1 = y(i1)                      // target value 1
            val e1 = errorFor (i1, al1)
            al2 = alp(i2)                       // old multiplier 2
            val y2 = y(i2)                      // target value 2
            val e2 = errorFor (i2, al2)

            val s      = y1 * y2
            val (l, h) = computeLH (y1, y2)     // ends of the feasible segment for a2

            if (l == h) {
                if (DEBUG) println ("takeStep: skip if l == h")
                false
            } else {
                // compute eta
                val k11 = kernel_func (i1, i1)
                val k12 = kernel_func (i1, i2)
                val k22 = kernel_func (i2, i2)
                val eta = 2.0 * k12 - (k11 + k22)

                a2 = newA2 (y2, e1, e2, l, h, eta)

                if (abs (a2 - al2) < EPSILON * (a2 + al2 + EPSILON)) {
                    if (DEBUG) println ("takeStep: skip if a2 = " + a2 + " ~=  al2 = " + al2)      // almost no change
                    false
                } else {
                    // move a1 by the opposite amount and pull it back into [0, C]
                    a1 = al1 - s * (a2 - al2)
                    if (a1 < 0.0) {
                        a2 += s * a1; a1 = 0.0
                    } else if (a1 > C) {
                        a2 += s * (a1 - C); a1 = C
                    } // if

                    update_threshold (y1, y2, e1, e2, k11, k12, k22)
                    update (i1, i2, y1, y2)                    // weights and eCache

                    alp(i1) = a1; alp(i2) = a2                 // store a1, a2 in alp array
                    true
                } // if
            } // if
        } // if
    } // takeStep

    //::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::
    /** Compute the candidate new value for the second multiplier: when 'eta' is
     *  negative, the unconstrained optimum clipped to '[l, h]'; otherwise
     *  whichever end point has the smaller objective value (by more than
     *  'EPSILON'), or the old value 'al2' when the two ends are too close to call.
     *  @param y2   the second target value
     *  @param e1   the error for i1
     *  @param e2   the error for i2
     *  @param l    the lower end of the feasible segment
     *  @param h    the upper end of the feasible segment
     *  @param eta  the value '2 k12 - (k11 + k22)'
     */
    private def newA2 (y2: Int, e1: Double, e2: Double, l: Double, h: Double, eta: Double): Double =
    {
        if (eta < 0.0) {
            val a = al2 + y2 * (e2 - e1) / eta
            if (a < l) l else if (a > h) h else a
        } else {
            val c1   = eta / 2.0
            val c2   = y2 * (e1 - e2) - eta * al2
            val lObj = c1 * l * l + c2 * l
            val hObj = c1 * h * h + c2 * h
            if (lObj > hObj + EPSILON) l else if (lObj < hObj - EPSILON) h else al2
        } // if
    } // newA2

    //::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::
    /** Compute 'l' and 'h', the lower and upper ends of the segment the second
     *  multiplier may move along, from the old multipliers 'al1' and 'al2'.
     *  @param y1  the first target value
     *  @param y2  the second target value
     */
    def computeLH (y1: Int, y2: Int): Pair =
    {
        if (y1 == y2) {
            val gamma = al1 + al2
            if (gamma > C) (gamma - C, C) else (0.0, gamma)
        } else {
            val gamma = al1 - al2
            if (gamma > 0.0) (0.0, C - gamma) else (-gamma, C)
        } // if
    } // computeLH

    //::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::
    /** Update the threshold 'b' and record the change in 'delta_b'.  The
     *  candidate derived from a non-bound new multiplier is preferred (a1 first,
     *  then a2); if both are at a bound, the two candidates are averaged.
     *  @param y1   the first target value
     *  @param y2   the second target value
     *  @param e1   the error for i1
     *  @param e2   the error for i2
     *  @param k11  the value of the kernel function at i1, i1
     *  @param k12  the value of the kernel function at i1, i2
     *  @param k22  the value of the kernel function at i2, i2
     */
    def update_threshold (y1: Int, y2: Int, e1: Double, e2: Double, k11: Double, k12: Double, k22: Double): Unit =
    {
        val b1 = b + e1 + y1 * (a1 - al1) * k11 + y2 * (a2 - al2) * k12
        val b2 = b + e2 + y1 * (a1 - al1) * k12 + y2 * (a2 - al2) * k22
        val bnew = if (nonBound (a1))      b1
                   else if (nonBound (a2)) b2
                   else                    (b1 + b2) / 2.0
        delta_b = bnew - b
        b = bnew
    } // update_threshold

    //::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::
    /** Update weights 'w' and error cache 'eCache'.  'takeStep' calls this before
     *  it stores a1 and a2 in 'alp', so the non-bound test below still sees the
     *  old multipliers.  The cache entries for i1 and i2 are reset to zero.
     *  @param i1  the index for the first Lagrange multiplier (alpha)
     *  @param i2  the index for the second Lagrange multiplier (alpha)
     *  @param y1  the first target value
     *  @param y2  the second target value
     */
    def update (i1: Int, i2: Int, y1: Int, y2: Int): Unit =
    {
        val t1 = y1 * (a1 - al1)
        val t2 = y2 * (a2 - al2)
        for (j <- 0 until n) w(j) += x(i1, j) * t1 + x(i2, j) * t2

        for (i <- 0 until ESI if nonBound (alp(i))) {
            eCache(i) += t1 * kernel_func (i1, i) + t2 * kernel_func (i2, i) - delta_b
        } // for
        eCache(i1) = 0.0
        eCache(i2) = 0.0
    } // update

    //::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::
    /** Check whether the first alpha 'alp(i1)' violates the KKT conditions by
     *  more than 'TOL', if so look for the second alpha 'alp(i2)' and jointly
     *  optimize the two alphas (Lagrange multipliers).  Returns true if a step
     *  was successfully taken.
     *  @param i1  the index for the first Lagrange multiplier (alpha)
     */
    def checkExample (i1: Int): Boolean =
    {
        if (DEBUG) println ("checkExample (" + i1 + ")")
        val y1 = y(i1)
        al1    = alp(i1)

        val e1 = errorFor (i1, al1)
        val r1 = y1 * e1

        val violates = r1 < -TOL && al1 < C || r1 > TOL && al1 > 0.0
        // try the three heuristics for finding i2 in order, stopping at the first success
        val changed  = violates && (try_i2_a (i1, e1) || try_i2_b (i1) || try_i2_c (i1))

        if (DEBUG && ! changed) println ("checkExample: returning false")
        changed
    } // checkExample

    //::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::
    /** Try to find 'i2' by maximizing the difference between e1 and e2 over the
     *  non-bound examples, using their cached errors.
     *  @param i1  the index for the first Lagrange multiplier (alpha)
     *  @param e1  the error for i1
     */
    def try_i2_a (i1: Int, e1: Double): Boolean =
    {
        if (DEBUG) println ("try_i2_a (" + i1 + ", " + e1 + ")")
        var i2   = -1                                  // -1 => no candidate found
        var dmax = 0.0
        for (k <- 0 until ESI if nonBound (alp(k))) {
            val diff = abs (e1 - eCache(k))
            if (diff > dmax) { dmax = diff; i2 = k }
        } // for
        i2 >= 0 && takeStep (i1, i2)
    } // try_i2_a

    //::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::
    /** Try to find 'i2' by iterating through the non-bound examples, starting
     *  at a random position and wrapping around.
     *  @param i1  the index for the first Lagrange multiplier (alpha)
     */
    def try_i2_b (i1: Int): Boolean =
    {
        if (DEBUG) println ("try_i2_b (" + i1 + ")")
        val k0 = (rn.gen * ESI).toInt
        (k0 until ESI + k0).exists { k =>
            val i2 = k % ESI
            nonBound (alp(i2)) && takeStep (i1, i2)
        } // exists
    } // try_i2_b

    //::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::
    /** Try to find 'i2' by iterating through the entire training set, starting
     *  at a random position and wrapping around.
     *  @param i1  the index for the first Lagrange multiplier (alpha)
     */
    def try_i2_c (i1: Int): Boolean =
    {
        if (DEBUG) println ("try_i2_c (" + i1 + ")")
        val k0 = (rn.gen * ESI).toInt
        (k0 until ESI + k0).exists (k => takeStep (i1, k % ESI))
    } // try_i2_c

    //::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::
    /** Solve the SMO optimization problem for the weight vector 'w' and the
     *  threshold 'b' for the model '(w dot z) - b'.  Passes alternate between
     *  all examples and only those whose multiplier is neither 0 nor 'C'; the
     *  loop ends when a pass over all examples changes nothing.
     */
    def solve (): Tuple2 [VectorD, Double] =
    {
        var nChanged = 0
        var checkAll = true

        while (nChanged > 0 || checkAll) {
            nChanged =
                if (checkAll) (0 until m).count (k => checkExample (k))
                else          (0 until m).count (k => alp(k) != 0.0 && alp(k) != C && checkExample (k))

            if (checkAll) checkAll = false
            else if (nChanged == 0) checkAll = true
            diagnose ()
        } // while
        (w, b)
    } // solve

    //::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::
    /** Print diagnostic information.
     */
    def diagnose (): Unit =
    {
        // FIX - to be implemented
    } // diagnose

} // SeqMinOpt class

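/* Modification 2 (commented-out, uncompiled sketch of an alternative main loop
   for 'solve'; it uses names such as b_up, b_lo, i_up, i_lo and fc that the class
   above does not define):
       while (nChanged > 0 || checkAll) {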
            nChanged = 0
            if (checkAll) {
                for (k <- 0 until m) nChanged += checkExample (k)
            } else {
                var success = true
                while (b_up <= b_lo - 2 * TOL && success) {
                    var i2  = i_lo
                    var y2  = y(i2)
                    var al2 = langrange (i2)
                    var f2  = fc(i2)
                    success = takeStep (i_up, i_lo)
                    nChanged += success
                } // while
                nChanged = 0
            } // if
            if (checkAll) checkAll = false
            else if (nChanged == 0 ) checkAll = true
        } // while
*/


//::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::
/** The `SeqMinOptTest` is used to test the `SeqMinOpt` class.
 */
object SeqMinOptTest extends App
{
    // training set: four 2-dimensional data points, one per row
    val x = new MatrixD ((4, 2), 1.0, 2.0,
                                 2.0, 1.0,
                                 2.0, 3.0,
                                 3.0, 2.0)

    // class label for each row of 'x'
    val y = VectorN (-1, -1, 1, 1)

    // build the optimizer and train it; 'model' holds (weights, threshold)
    val smo   = new SeqMinOpt (x, y)
    val model = smo.solve ()

    model match {
        case (weights, threshold) =>
            println ("model w = " + weights)
            println ("model b = " + threshold)
    } // match

} // SeqMinOptTest object

